# Modified from RFdiffusion (https://github.com/RosettaCommons/RFdiffusion)
# Original license: BSD License
#
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


import sys

import mindspore as ms
import numpy as np

from . import potentials as potentials


def make_contact_matrix(nchain, intra_all=False, inter_all=False, contact_string=None):
    """
    Calculate a matrix of inter/intra chain contact indicators

    Entry [i, j] is 1 for an attractive contact between chains i and j, -1 for
    a repulsive one and 0 when no contact term is requested.

    Parameters:
        nchain (int, required): How many chains are in this design

        intra_all (bool, optional): Mark every diagonal (chain with itself)
            entry as attractive

        inter_all (bool, optional): Mark every off-diagonal (chain with a
            different chain) entry as attractive

        contact_string (str, optional): String denoting how to define contacts, comma delimited between pairs of chains
            '!' denotes repulsive, '&' denotes attractive, e.g. 'A!B,B&C'.
            Chains are named by upper-case letters ('A' is chain 0). These
            entries are written last, symmetrically, and therefore override
            whatever intra_all / inter_all put in the same cells.

    Returns:
        np.ndarray: (nchain, nchain) float array holding 0, 1 or -1

    Raises:
        AssertionError: an entry is not exactly 3 characters long or its
            middle character is neither '!' nor '&'
        KeyError: a chain name is not an upper-case ASCII letter
        IndexError: a chain letter refers to a chain index >= nchain
    """
    contacts = np.zeros((nchain, nchain))
    on_diagonal = np.eye(nchain, dtype=bool)

    # intra_all - everything on the diagonal has contact potential
    if intra_all:
        contacts[on_diagonal] = 1

    # inter all - everything off the diagonal has contact potential
    if inter_all:
        contacts[~on_diagonal] = 1

    # custom contacts/repulsions from user
    if contact_string is not None:
        letter2num = {a: i for i, a in enumerate("ABCDEFGHIJKLMNOPQRSTUVWXYZ")}

        for c in contact_string.split(","):
            assert (
                len(c) == 3
            ), f"contact entry {c!r} must look like 'A!B' or 'A&B'"
            i, j = letter2num[c[0]], letter2num[c[2]]

            symbol = c[1]
            assert symbol in (
                "!",
                "&",
            ), f"contact symbol {symbol!r} in {c!r} must be '!' or '&'"

            # Contacts are undirected, so keep the matrix symmetric.
            value = -1 if symbol == "!" else 1
            contacts[i, j] = value
            contacts[j, i] = value

    return contacts


def calc_nchains(symbol, components=1):
    """
    Calculates number of chains for given symmetry

    Parameters:
        symbol (str, required): Symmetry symbol, matched case-insensitively on
            its first letter:
              'c<n>' -> n chains (e.g. 'C3' -> 3)
              'd<n>' -> 2 * n chains (e.g. 'D2' -> 4)
              't...' -> 12 chains
              'o...' -> not implemented

        components (int, optional): Multiplier applied to the chain count

    Returns:
        int: Number of chains

    Raises:
        NotImplementedError: the symbol starts with 'o'
        RuntimeError: the symbol starts with any other unhandled letter
            (NotImplementedError is itself a RuntimeError subclass)
        ValueError: the text after 'c' / 'd' is not an integer
    """
    S = symbol.lower()

    if S.startswith("c"):
        return int(S[1:]) * components
    elif S.startswith("d"):
        return 2 * int(S[1:]) * components
    elif S.startswith("o"):
        raise NotImplementedError(
            f"Chain count for symmetry symbol {S} is not implemented"
        )
    elif S.startswith("t"):
        return 12 * components
    else:
        # Single formatted message: passing S as a second positional argument
        # made the exception render as a tuple.
        raise RuntimeError(f"Unknown symmetry symbol {S}")


class PotentialManager:
    """
    Class to define a set of potentials from the given config object and to apply all of the specified potentials
    during each cycle of the inference loop.

    Author: NRB
    """

    def __init__(
        self,
        potentials_config,
        ppi_config,
        diffuser_config,
        inference_config,
        hotspot_0idx,
        binderlen,
    ):
        """
        Parse the configured guiding potentials and instantiate them.

        Inputs:

            potentials_config:          Config object; guide_scale, guide_decay and guiding_potentials are read
                                        here, and the olig_* fields when a symmetry is set

            ppi_config:                 Config object; stored on the instance, not otherwise used by this class

            diffuser_config:            Config object; only T is read

            inference_config:           Config object; symmetry is read when potentials are initialized

            hotspot_0idx:               Accepted for interface compatibility; see the NOTE in the body

            binderlen (int, required):  Binder length; values > 0 are added to the settings of every potential
                                        whose type is in potentials.require_binderlen
        """

        self.potentials_config = potentials_config
        self.ppi_config = ppi_config
        self.inference_config = inference_config

        self.guide_scale = potentials_config.guide_scale
        self.guide_decay = potentials_config.guide_decay

        if potentials_config.guiding_potentials is None:
            setting_list = []
        else:
            setting_list = [
                self.parse_potential_string(potstr)
                for potstr in potentials_config.guiding_potentials
            ]

        # PPI potentials require knowledge about the binderlen which may be detected at runtime
        # This is a mechanism to still allow this info to be used in potentials - NRB
        # NOTE(review): hotspot_0idx is not forwarded to any potential; the
        # previous code built a {"hotspot_res": hotspot_0idx} dict here but
        # never applied it. Confirm whether any potential needs it.
        if binderlen > 0:
            for setting in setting_list:
                if setting["type"] in potentials.require_binderlen:
                    setting["binderlen"] = binderlen

        self.potentials_to_apply = self.initialize_all_potentials(setting_list)
        self.T = diffuser_config.T

    def is_empty(self):
        """
        Check whether this instance of PotentialManager actually contains any potentials
        """

        return len(self.potentials_to_apply) == 0

    def parse_potential_string(self, potstr):
        """
        Parse a single entry in the list of potentials to be run to a dictionary of settings for that potential.

        An example of how this parsing is done:
        'setting1:val1,setting2:val2,setting3:val3' -> {setting1:val1,setting2:val2,setting3:val3}

        The value stored under 'type' stays a string; every other value is converted to float.

        Raises:
            IndexError: an entry contains no ':'
            ValueError: a non-'type' value cannot be converted to float
        """

        # Split every entry once; only the first two ':'-separated fields are used.
        fields_per_entry = [entry.split(":") for entry in potstr.split(",")]
        raw_settings = {fields[0]: fields[1] for fields in fields_per_entry}

        return {
            key: value if key == "type" else float(value)
            for key, value in raw_settings.items()
        }

    def initialize_all_potentials(self, setting_list):
        """
        Given a list of potential dictionaries where each dictionary defines the configurations for a single potential,
        initialize all potentials and add to the list of potentials to be applied

        Each dictionary must have a 'type' key naming an entry of potentials.implemented_potentials; the remaining
        entries are passed to that potential as keyword arguments. When inference_config.symmetry is set, a
        'contact_matrix' keyword argument built from the olig_* config fields is added for every potential.

        Returns:
            list: One instantiated potential per entry of setting_list, in the same order
        """

        to_apply = []

        for potential_dict in setting_list:
            assert (
                potential_dict["type"] in potentials.implemented_potentials
            ), f'potential with name: {potential_dict["type"]} is not one of the implemented potentials: {potentials.implemented_potentials.keys()}'

            kwargs = {k: v for k, v in potential_dict.items() if k != "type"}

            # symmetric oligomer contact potential args
            # Kept inside the loop on purpose: with no potentials configured the
            # symmetry symbol is never parsed, so an unsupported symbol cannot
            # raise here.
            if self.inference_config.symmetry:

                num_chains = calc_nchains(
                    symbol=self.inference_config.symmetry, components=1
                )  # hard code 1 for now
                kwargs["contact_matrix"] = make_contact_matrix(
                    nchain=num_chains,
                    intra_all=self.potentials_config.olig_intra_all,
                    inter_all=self.potentials_config.olig_inter_all,
                    contact_string=self.potentials_config.olig_custom_contact,
                )

            to_apply.append(
                potentials.implemented_potentials[potential_dict["type"]](**kwargs)
            )

        return to_apply

    def compute_all_potentials(self, xyz):
        """
        This is the money call. Take the current structure information and get the sum of all of the potentials that are being used

        Every potential's compute(xyz) result is stacked along a new leading dimension and summed over it.

        NOTE(review): presumably callers check is_empty() first - stacking an empty list is not handled here.
        """

        potential_list = [
            potential.compute(xyz) for potential in self.potentials_to_apply
        ]
        potential_stack = ms.mint.stack(potential_list, dim=0)

        return ms.mint.sum(potential_stack, dim=0)

    def get_guide_scale(self, t):
        """
        Given a timestep and a decay type, get the appropriate scale factor to use for applying guiding potentials

        Inputs:

            t (int, required):          The current timestep

        Output:

            scale (number):             The scale factor to use for applying guiding potentials. 'constant' returns
                                        guide_scale unchanged; the other decay types use true division and so
                                        yield a float for int inputs

        Exits the interpreter through sys.exit (SystemExit) when guide_decay is not an implemented decay type.
        """

        implemented_decay_types = {
            "constant": lambda t: self.guide_scale,
            # Linear interpolation with y2: 0, y1: guide_scale, x2: 0, x1: T, x: t
            "linear": lambda t: t / self.T * self.guide_scale,
            "quadratic": lambda t: t**2 / self.T**2 * self.guide_scale,
            "cubic": lambda t: t**3 / self.T**3 * self.guide_scale,
        }

        if self.guide_decay not in implemented_decay_types:
            sys.exit(
                f"decay_type must be one of {implemented_decay_types.keys()}. Received decay_type={self.guide_decay}. Exiting."
            )

        return implemented_decay_types[self.guide_decay](t)
